#pragma once

#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <forward_list>
#include <utility>
#include <vector>

#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

namespace torch {
namespace profiler {
namespace impl {

// ============================================================================
// == AppendOnlyList ==========================================================
// ============================================================================
//   During profiling, we have a very predictable access pattern: we only
// append to the end of the container. We can specialize and outperform both
// std::vector (which must realloc) and std::deque (which performs a double
// indirection), and this class of operation is sufficiently important to the
// profiling hot path to warrant specializing:
//   https://godbolt.org/z/rTjozf1c4
//   https://quick-bench.com/q/mmfuu71ogwaiULDCJyHdKnHZms4    (Prototype #1, int)
//   https://quick-bench.com/q/5vWDW6jjdXVdoffev2zst8D09no    (Prototype #1, int pair)
//   https://quick-bench.com/q/IfEkfAQMeJSNBA52xtMP6Agcl-Q    (Prototype #2, int pair)
//   https://quick-bench.com/q/wJV2lKmuXL4XyGJzcI5hs4gEHFg    (Prototype #3, int pair)
//   https://quick-bench.com/q/xiO8ZaBEkYRYUA9dFrMuPLlW9fo    (Full impl, int pair)
// AppendOnlyList has 2x lower emplace overhead compared to more generic STL
// containers.
//
//   The optimal value of `ChunkSize` will vary by use case, but testing shows
// that a value of 1024 does a good job amortizing the `malloc` cost of growth.
// Performance drops off for larger values, so testing on a case-by-case basis
// is recommended if performance is absolutely critical.

template <typename T, size_t ChunkSize>
class AppendOnlyList {
 public:
  using array_t = std::array<T, ChunkSize>;
  static_assert(ChunkSize > 0, "Block cannot be empty.");

  AppendOnlyList() : buffer_last_{buffer_.before_begin()} {}
  AppendOnlyList(const AppendOnlyList&) = delete;
  AppendOnlyList& operator=(const AppendOnlyList&) = delete;

  // Number of elements appended so far.
  //   `next_ - end_` is always <= 0 (how far we are from filling the current
  // block). The cast to size_t wraps modulo 2^N, and adding the full
  // `n_blocks_ * ChunkSize` cancels the wrap, yielding
  // `n_blocks_ * ChunkSize - (end_ - next_)`. When empty, all terms are zero.
  size_t size() const {
    return n_blocks_ * ChunkSize + (size_t)(next_ - end_);
  }

  // Appends a new element constructed from `args` and returns a pointer to
  // it. Returned pointers remain valid for the lifetime of the container:
  // blocks are never moved or reallocated (unlike std::vector).
  template <class... Args>
  T* emplace_back(Args&&... args) {
    maybe_grow();
    *next_ = {std::forward<Args>(args)...};
    return next_++;
  }

  // Releases all storage and returns the container to its default state.
  void clear() {
    buffer_.clear();
    // BUGFIX: reset to `before_begin()` to mirror the constructor. After
    // `buffer_.clear()`, `begin() == end()`, so the previous code left
    // `buffer_last_` as the end iterator; the next `maybe_grow()` would
    // then call `emplace_after(end())`, which is undefined behavior.
    buffer_last_ = buffer_.before_begin();
    n_blocks_ = 0;
    next_ = nullptr;
    end_ = nullptr;
  }

  struct Iterator {
    using iterator_category = std::forward_iterator_tag;
    using difference_type   = std::ptrdiff_t;
    using value_type        = T;
    using pointer           = T*;
    using reference         = T&;

    Iterator(std::forward_list<array_t>& buffer, const size_t size)
      : block_{buffer.begin()}, size_{size} {}

    // End iterator. (`size_` of zero makes `address()` report "past the
    // end" immediately, so any two exhausted iterators compare equal.)
    Iterator() = default;

    reference operator*() const { return *current_ptr(/*checked=*/true); }
    pointer operator->() { return current_ptr(/*checked=*/true); }

    // Prefix increment. Advancing across a ChunkSize boundary moves to the
    // next block; stepping one past the final element of a full last block
    // advances `block_` to the list's end iterator, which is valid as long
    // as it is never dereferenced (address() guards all dereferences).
    Iterator& operator++() {
      if (!(++current_ % ChunkSize)) {
        block_++;
      }
      return *this;
    }

    // Postfix increment
    Iterator operator++(int) { Iterator tmp = *this; ++(*this); return tmp; }

    // Comparison via `current_ptr()`: every past-the-end iterator maps to
    // nullptr, so an advanced iterator compares equal to a default
    // constructed "end" iterator.
    friend bool operator==(const Iterator& a, const Iterator& b) {
      return a.current_ptr() == b.current_ptr();
    }
    friend bool operator!=(const Iterator& a, const Iterator& b) {
      return a.current_ptr() != b.current_ptr();
    }

    // Returns {block, index-within-block} for the current position, or
    // {nullptr, 0} when the iterator is at or past the end.
    std::pair<array_t*, size_t> address() const {
      if (current_ >= size_){
        return {nullptr, 0};
      }
      return {&(*block_), current_ % ChunkSize};
    }

   private:
    // `checked=true` (dereference paths) turns an out-of-range access into
    // a hard assert; `checked=false` (comparison paths) simply yields
    // nullptr so end-iterator comparisons work.
    T* current_ptr(bool checked = false) const {
      auto a = address();
      if (a.first == nullptr) {
        TORCH_INTERNAL_ASSERT(!checked, "Invalid access on AppendOnlyList.");
        return nullptr;
      }
      return a.first->data() + a.second;
    }

    typename std::forward_list<array_t>::iterator block_;
    size_t current_ {0};
    size_t size_ {0};
  };

  Iterator begin() { return Iterator(buffer_, size()); }
  Iterator end()   { return Iterator(); }
  // TODO: cbegin and cend()

// TODO: make private
 protected:
  // Allocates a fresh block when the current one is full (or on first use,
  // since `next_ == end_ == nullptr` initially). Appending after
  // `buffer_last_` keeps insertion O(1) despite forward_list only exposing
  // `emplace_after`.
  void maybe_grow() {
    if (C10_UNLIKELY(next_ == end_)) {
      buffer_last_ = buffer_.emplace_after(buffer_last_);
      n_blocks_++;
      next_ = buffer_last_->data();
      end_ = next_ + ChunkSize;
    }
  }

  std::forward_list<array_t> buffer_;

  // We maintain a pointer to the last element of `buffer_` so that we can
  // insert at the end in O(1) time.
  typename std::forward_list<array_t>::iterator buffer_last_;
  size_t n_blocks_ {0};
  T* next_ {nullptr};
  T* end_ {nullptr};
};

} // namespace impl
} // namespace profiler
} // namespace torch
